Star classification project

[Image: Gaia spacecraft and the Milky Way. Credit: Spacecraft: ESA/ATG medialab; Milky Way: ESA/Gaia/DPAC; acknowledgement: A. Moitinho. CC BY-SA 3.0 IGO]

The goal of this project is to classify stars from a star catalogue. For this project, we will use stars observed by the Gaia space telescope of the European Space Agency. This telescope was launched in 2013 and was designed for astrometry, which consists of measuring the positions, distances, magnitudes and other properties of stars.

Stellar classification is the classification of stars based on their spectral characteristics. There are several different classification systems, but the most widely used one is the Morgan-Keenan (MK) system. This system labels stars with the letters O, B, A, F, G, K and M:

[Image: the MK spectral classes]

By Pablo Carlos Budassi - Own work, CC BY-SA 4.0

You can find more information about stellar classification on Wikipedia. The dataset we'll use is from Kaggle. It consists of a subset of the Gaia data (fortunately we don't work with all the objects captured by the telescope, as there are about 1,800 million of them!)

In [ ]:
# general imports 
import polars as pl 
import pandas as pd 
import numpy as np
import itertools
import warnings


# Plotting tools import 
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import matplotlib.colors as mpl_colors

# Astronomy imports 
from astropy.coordinates import SkyCoord
from astropy import units as u

# ML imports 
from sklearn.preprocessing import LabelEncoder

# Filter warnings 
warnings.filterwarnings('ignore', category=matplotlib.MatplotlibDeprecationWarning)

1. Exploratory data analysis

First, let's import our dataset into a polars dataframe. Polars is an alternative to pandas, written in Rust. Its performance is generally better than that of pandas, and queries on the dataframe read much like SQL queries. Check out the library!

In [ ]:
df = pl.read_csv('/kaggle/input/gaia-stars-dataset-from-dr3-data-release-3/dataGaia2.csv')
print(f'The dataframe has {df.shape[1]} columns and {df.shape[0]} rows')
df.head()
The dataframe has 50 columns and 626016 rows
Out[ ]:
shape: (5, 50)
[Truncated df.head() output: the first 5 rows of all 50 columns, from RA_ICRS and DE_ICRS through SpType-ELS and f_EWHa]

There are a lot of columns; we won't keep them all.

More detailed information on the data is available on the Gaia website.

We will remove some of these columns from the start because they are not relevant to this study and we won't use them in any statistics. We also do some basic preprocessing to be able to work with the data:

  • Strip whitespace from the SpType-ELS column.
  • Change the datatype of the column Age-Flame from string to float.

Here are all the columns:

In [ ]:
for column in df.columns:
    print(column)
RA_ICRS
DE_ICRS
Source
e_RA_ICRS
e_DE_ICRS
Plx
e_Plx
PM
pmRA
e_pmRA
pmDE
e_pmDE
RUWE
Gmag
e_Gmag
BPmag
e_BPmag
RPmag
e_RPmag
GRVSmag
e_GRVSmag
RV
logg
[Fe/H]
Dist
PQSO
PGal
Pstar
PWD
Pbin
Teff
A0
AG
ABP
ARP
E(BP-RP)
GMAG
Rad
Rad-Flame
Lum-Flame
Mass-Flame
Age-Flame
z-Flame
Evol
SpType-ELS
Flags-HS
EWHa
e_EWHa
f_EWHa
In [ ]:
to_drop = [
    'Source',
    'e_RA_ICRS',
    'e_DE_ICRS',
    'Plx',
    'e_Plx',
    'PM',
    'pmRA',
    'e_pmRA',
    'pmDE',
    'e_pmDE',
    'RUWE',
    'e_Gmag',
    'e_BPmag',
    'e_RPmag',
    'GRVSmag',
    'e_GRVSmag',
    'RV',
    'PQSO',
    'PGal',
    'Pstar',
    'PWD',
    'Pbin',
    'A0',
    'E(BP-RP)',
    'Flags-HS',
    'EWHa',
    'e_EWHa',
    'f_EWHa'
]
df = df.drop(to_drop)

# strip the stellar types
df = df.with_columns(pl.col('SpType-ELS').str.strip_chars())

# Change the data type for Age-Flame
df = df.with_columns(pl.col('Age-Flame').cast(pl.Float64))

print(df.shape)
(626016, 22)

Here is a description of the remaining columns:

  • RA_ICRS: Right ascension in the ICRS (International Celestial Reference System) coordinate system
  • DE_ICRS: Declination in the ICRS coordinate system
  • Gmag: Average apparent magnitude integrated in the G band
  • BPmag: Average apparent magnitude integrated in the BP blue band
  • RPmag: Average apparent magnitude integrated in the RP red band
  • logg: Surface gravity
  • Dist: Distance to the celestial object: inverse of the parallax, in parsecs
  • Teff: Estimated effective temperature of the celestial object by Gaia in Kelvins
  • GMAG: Absolute Gmag estimated from Gaia
  • AG: Extinction in the G band
  • ABP: Extinction in the BP band
  • ARP: Extinction in the RP band
  • Rad: Object radius estimate in terms of solar radius
  • Lum-Flame: Estimated object luminosity in terms of solar luminosity
  • Mass-Flame: Mass estimate in terms of solar mass
  • Age-Flame: Celestial object age in giga years
  • z-Flame: Redshift in km/s
  • Evol: Evol stage
  • SpType-ELS: Estimated spectral class by Gaia
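Since Dist is described above as the inverse of the parallax, and Gaia parallaxes are published in milliarcseconds, the conversion to parsecs can be sketched as follows (the helper and its values are illustrative, not catalogue data):

```python
def parallax_to_distance_pc(plx_mas: float) -> float:
    """Distance in parsecs from a parallax in milliarcseconds: d = 1000 / plx."""
    return 1000.0 / plx_mas

# A star with a parallax of 2 mas sits at 500 pc.
print(parallax_to_distance_pc(2.0))   # 500.0
# Smaller parallax means a more distant star.
print(parallax_to_distance_pc(0.5))   # 2000.0
```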

First, let's take a brief look at the coordinates of our stars. The next plot is an Aitoff projection, commonly used in astronomy, of the positions of the stars. We give each star type a color.

In [ ]:
coords = SkyCoord(ra=df['RA_ICRS'].to_numpy(), dec=df['DE_ICRS'].to_numpy(), unit='deg')
ra_rad = coords.ra.wrap_at(180 * u.deg).radian
dec_rad = coords.dec.radian

unique_spectral_types = np.sort(df['SpType-ELS'].unique().to_numpy())

color_map = matplotlib.colormaps['Spectral_r'].resampled(len(unique_spectral_types))
spectral_type_to_color = {stype: color_map(i) for i, stype in enumerate(unique_spectral_types)}
colors = [spectral_type_to_color[stype] for stype in df['SpType-ELS'].to_numpy()]
my_colors = [spectral_type_to_color[stype] for stype in sorted(spectral_type_to_color)]

fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='aitoff')
scatter = ax.scatter(ra_rad, dec_rad, s=0.2, c=colors, alpha=0.2)

for stype, color in spectral_type_to_color.items():
    ax.scatter([], [], c=[color], label=stype)

plt.title('Aitoff Projection of Stars Colored by Spectral Type \n')
plt.xlabel('Right Ascension (RA)')
plt.ylabel('Declination (DE)')
plt.grid(True)
plt.legend(loc='lower left', bbox_to_anchor=(1.05, 0), title='Spectral Type', markerscale=3)
plt.show()
In [ ]:
fig, axes = plt.subplots(nrows=(len(unique_spectral_types) + 2) // 3, ncols=3, figsize=(15, 10), subplot_kw={'projection': 'aitoff'})
fig.subplots_adjust(hspace=0.3, wspace=0.3)

axes = axes.flatten()

for i, spectral_type in enumerate(unique_spectral_types):
    
    plot_df = df.filter(pl.col('SpType-ELS') == spectral_type)
    coords = SkyCoord(ra=plot_df['RA_ICRS'].to_numpy(), dec=plot_df['DE_ICRS'].to_numpy(), unit='deg')
    ra_rad = coords.ra.wrap_at(180 * u.deg).radian
    dec_rad = coords.dec.radian
    
    axes[i].hexbin(ra_rad, dec_rad, gridsize=45, bins='log')
    axes[i].set_title(f'Hexbin of {spectral_type.strip()} stars \n ')
    axes[i].grid(True)

for i in range(len(unique_spectral_types), len(axes)):
    axes[i].axis('off')

The stars are not evenly distributed across the sky, but this should not affect our study.

Let's take a look at our response variable, SpType-ELS.

In [ ]:
plot_df = (
    df
    .group_by(pl.col('SpType-ELS'))
    .agg(pl.len().alias('Count'))
    .sort('SpType-ELS')
)

plot = sns.barplot(x=plot_df['SpType-ELS'].to_numpy(), 
                   y=plot_df['Count'].to_numpy(), 
                   palette='Spectral_r', 
                   saturation=0.9
                  )
plot.set_title('Count of stars by type')
plot.set_xlabel('Stellar classification')
plot.set_ylabel('Count')
Out[ ]:
Text(0, 0.5, 'Count')

We have imbalanced data. We will take care of that later.

Another interesting tool for stars is the HR diagram (Hertzsprung-Russell diagram). The diagram shows the relationship between the magnitude (or luminosity) of stars and their effective temperature. More information about this diagram is on Wikipedia.

In [ ]:
temperature = df['Teff'].to_numpy() 
luminosity = df['Lum-Flame'].to_numpy()  


plt.figure(figsize=(10, 6))
plt.scatter(temperature, 
            luminosity, 
            c=colors, 
            s=0.2,
            alpha=0.5, 
            edgecolor='none')

plt.xscale('log')
plt.yscale('log')
plt.xlabel('Effective Temperature (K)')
plt.ylabel('Luminosity (L/Lsun)')
plt.title('Hertzsprung-Russell Diagram')

for stype, color in spectral_type_to_color.items():
    plt.scatter([], [], c=[color], label=stype)

plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left", title='Spectral Type', markerscale=3)
plt.grid(True)
plt.gca().invert_xaxis()
plt.show()

This graphic is very interesting and already shows some differences between the stars given their classification. Let's look at this diagram for each spectral type.

In [ ]:
fig, axes = plt.subplots(nrows=(len(unique_spectral_types) + 2) // 3, ncols=3, figsize=(15, 10))
fig.subplots_adjust(hspace=0.5, wspace=0.5)

axes = axes.flatten()

x_min = np.nanmin(temperature)
x_max = np.nanmax(temperature)
y_min = np.nanmin(luminosity)
y_max = np.nanmax(luminosity)

for i, j in enumerate(spectral_type_to_color.items()):  
    spectral_type, color = j
    plot_df = df.filter(pl.col('SpType-ELS') == spectral_type)
    temp = plot_df['Teff'].to_numpy()  
    lum = plot_df['Lum-Flame'].to_numpy()  
    
    ax = axes[i]
    sc = ax.scatter(temp, 
                    lum,
                    c=[color],
                    s=1,
                    alpha=0.5, 
                    edgecolor='none'
                   )
    
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.invert_xaxis()
    ax.set_xlabel('Effective Temperature (K)')
    ax.set_ylabel('Luminosity (L/Lsun)')
    ax.set_title(f'HR Diagram for {spectral_type}')
    ax.grid(True)

# Hide any unused subplots
for i in range(len(unique_spectral_types), len(axes)):
    axes[i].axis('off')
    
plt.savefig('HR_cat.png', transparent=True)

Now let's take a look at the radius as a function of the temperature. The radius is expressed as a ratio to the solar radius.

In [ ]:
plt.figure(figsize=(10, 6))

plt.scatter(df['Teff'].to_numpy(), 
            df['Rad'].to_numpy(), 
            c=colors, 
            alpha=0.5,
            s=0.2,
            edgecolor='none'
           )

plt.yscale('log')
plt.xlabel('Effective Temperature (K)')
plt.ylabel('Stellar radius (R/R☉)')
plt.title('Radius vs. Temperature')

for stype, color in spectral_type_to_color.items():
    plt.scatter([], [], c=[color], label=stype)

plt.legend(loc='upper right', title='Spectral Type', markerscale=3)
plt.grid(True)
plt.savefig('Temp_radius.png', transparent=True)

These plots are classical astronomical plots. Let's now inspect some of the features of the dataset. We will proceed by groups of attributes.

In [ ]:
def generate_data(df,
                 feature_name
                 ):
    data = (
        df
        .select(['SpType-ELS', feature_name])
        .drop_nulls(feature_name)
        .group_by(pl.col('SpType-ELS'))
        .agg(pl.col(feature_name))
        .sort(pl.col('SpType-ELS'))
    )[feature_name].to_numpy()
    
    s_types = df['SpType-ELS'].unique().sort().to_numpy().tolist()
    x = [''] + s_types
    
    return (x, data)

def format_ax(ax, 
              title, 
              xaxis_title, 
              yaxis_title, 
              log_scale
             ):
    
    if log_scale:
        ax.set_yscale('log')
    if title:
        ax.set_title(title)
    if xaxis_title:
        ax.set_xlabel(xaxis_title)
    if yaxis_title:
        ax.set_ylabel(yaxis_title)
    
    return ax

def boxplot_feature(df, 
                    feature_name, 
                    title = None, 
                    xaxis_title = None,                    
                    yaxis_title = None, 
                    log_scale = False, 
                    ax=None
                   ):
    
    x, data = generate_data(df, feature_name)
    
    box = ax.boxplot(data, 
                      notch=False, 
                      patch_artist=True,
                      flierprops=dict(markersize=2), 
                      medianprops=dict(color='black', linewidth=2.5)
                     )

    ax = format_ax(ax=ax, 
                  title=title, 
                  xaxis_title=xaxis_title, 
                  yaxis_title=yaxis_title, 
                  log_scale=log_scale)
        
    for patch, color in zip(box['boxes'], my_colors):
        patch.set_facecolor(color)
        
    plt.xticks(range(len(x)), x)
    return box


def violinplot_feature(df, 
                      feature_name, 
                      title = None, 
                      xaxis_title = None,                    
                      yaxis_title = None, 
                      log_scale = False, 
                      ax=None
                     ):

    x, data = generate_data(df, feature_name)

    violin = ax.violinplot(data, 
                          showmedians=True,
                          widths=0.9,
                          bw_method='scott',
                          )

    ax = format_ax(ax=ax, 
                   title=title, 
                   xaxis_title=xaxis_title, 
                   yaxis_title=yaxis_title, 
                   log_scale=log_scale)

    for pc, color in zip(violin['bodies'], my_colors):
        pc.set_facecolor(color)        

    plt.xticks(range(len(x)), x)

    return violin
In [ ]:
def plot_subplots(data_to_plot, df, plot_func=boxplot_feature):
    s_types = df['SpType-ELS'].unique().sort().to_numpy().tolist()
    x = [''] + s_types
    
    nrows = len(data_to_plot) // 3 + (len(data_to_plot) % 3 > 0)

    fig, axs = plt.subplots(nrows=nrows, ncols=3, figsize=(20, 10))
    fig.subplots_adjust(hspace=0.5, wspace=0.5)
    
    axes = axs.flatten()

    for i, d in enumerate(data_to_plot):
        row = i // 3
        col = i % 3

        my_plot = plot_func(df, **d, ax=axs[row, col])
        axs[row, col].set_title(d['title'])
        axs[row, col].set_xticks(range(len(x)), x)
    
    for i in range(len(data_to_plot), len(axes)):
        axes[i].axis('off')


    return my_plot

Let's start with the common properties:

  • Teff: Estimated effective temperature of the celestial object by Gaia in Kelvins.
  • Dist: Distance to the celestial object: inverse of the parallax, in parsecs.
  • Rad: Object radius estimate in terms of solar radius.
  • Lum-Flame: Estimated object luminosity in terms of solar luminosity.
  • Mass-Flame: Mass estimate in terms of solar mass.
  • Age-Flame: Celestial object age in giga years.
In [ ]:
data_to_plot = [
    dict(
        feature_name='Teff', 
        title = 'Estimated effective temperature', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Temperature (K)', 
        log_scale = True
    ),
    dict(
        feature_name = 'Dist', 
        title = 'Distance to star, inverse of the parallax', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Inverse of parallax (pc)', 
        log_scale = True
    ), 
    dict(
        feature_name = 'Rad', 
        title = 'Stellar radius in terms of solar radius', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Radius (R/R☉)', 
        log_scale = True
    ), 
    dict(
        feature_name = 'Lum-Flame', 
        title = 'Estimated luminosity in terms of solar luminosity', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Luminosity (L/L☉)', 
        log_scale = True
    ),
    dict(
        feature_name = 'Mass-Flame', 
        title = 'Estimated mass in terms of solar mass', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Mass (M/M☉)', 
        log_scale = True
    ),
    dict(
        feature_name = 'Age-Flame', 
        title = 'Estimated age', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Age (Gyr)', 
        log_scale = False
    ),
    dict(
        feature_name = '[Fe/H]', 
        title= 'Metallicity',
        xaxis_title='Stellar type', 
        yaxis_title = '[FE/H] (dex)',
        log_scale=False        
    ),
    dict(
        feature_name = 'logg', 
        title= 'Surface gravity',
        xaxis_title='Stellar type', 
        yaxis_title = 'log(g) $(cm/s^2)$',
        log_scale=False        
    ),
    dict(
        feature_name = 'GMAG',
        title = 'Absolute Gmag estimated from Gaia',
        xaxis_title = 'Stellar type',
        yaxis_title='Absolute magnitude',
        log_scale=False,
    )
]

plot = plot_subplots(data_to_plot, df)
plt.savefig('stats_1.png', transparent=True)

We can already spot some differences between the star types in these boxplots. We'll perform a correlation analysis later. We can see that we don't have any data for the age of the O type stars, so we won't take this feature into consideration in the analysis. Regarding the distance, the whiskers span a very wide range for each type, so we won't use this feature either; additionally, the distance of a star from Earth is not an indicator of its type (domain knowledge). We can already say that class M will be a challenging one.

Let's perform the same analysis on the magnitude features (more info on magnitude on Wikipedia):

  • Gmag: Average apparent magnitude integrated in the G band.
  • BPmag: Average apparent magnitude integrated in the BP blue band.
  • RPmag: Average apparent magnitude integrated in the RP red band.
  • GMAG: Absolute Gmag estimated from Gaia.
  • A0: Extinction in the A0 line.
  • AG: Extinction in the G band.
  • ABP: Extinction in the BP band.
  • ARP: Extinction in the RP band.

[Image: CactiStaccingCrane, CC0, via Wikimedia Commons]

Because the magnitudes in this dataset are apparent magnitudes (the magnitude we see, not the intrinsic magnitude of the star), we need to calculate the absolute magnitudes first. At the same time, we will calculate the color indices and the corrected color indices.
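The conversion applied in the next cell is the distance modulus with an extinction correction, $M = m - 5\log_{10}(d) + 5 - A$. A quick numeric sanity check (the helper function and its values are illustrative):

```python
import math

def absolute_magnitude(m: float, dist_pc: float, extinction: float = 0.0) -> float:
    """Distance modulus: M = m - 5*log10(d) + 5 - A, with d in parsecs."""
    return m - 5 * math.log10(dist_pc) + 5 - extinction

# At d = 10 pc with no extinction, apparent and absolute magnitude coincide.
print(absolute_magnitude(4.83, 10.0))   # 4.83
# The same apparent magnitude at 100 pc implies an intrinsically brighter star.
print(absolute_magnitude(4.83, 100.0))  # -0.17
```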

In [ ]:
df = (
    df
    .with_columns(
        Gmag_abs=pl.col('Gmag') - 5 * np.log10(pl.col('Dist')) + 5 - pl.col('AG'), 
        BPmag_abs=pl.col('BPmag') - 5 * np.log10(pl.col('Dist')) + 5 - pl.col('ABP'),
        RPmag_abs=pl.col('RPmag') - 5 * np.log10(pl.col('Dist')) + 5 - pl.col('ARP'), 
    )
    .with_columns(
        BP_RP_abs = pl.col('BPmag_abs') - pl.col('RPmag_abs'),
        BP_G_abs = pl.col('BPmag_abs') - pl.col('Gmag_abs'),
        G_RP_abs = pl.col('Gmag_abs') - pl.col('RPmag_abs')
    )
)
In [ ]:
data_to_plot = [
    dict(
        feature_name='Gmag', 
        title = 'Apparent magnitude in the G band', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Magnitude', 
        log_scale = True
    ),
    dict(
        feature_name = 'BPmag', 
        title = 'Average apparent magnitude in the BP band', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Magnitude', 
        log_scale = True
    ), 
    dict(
        feature_name = 'RPmag', 
        title = 'Average apparent magnitude in the RP band', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Magnitude', 
        log_scale = True
    ), 
    dict(
        feature_name='Gmag_abs', 
        title = 'Absolute magnitude in the G band', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Magnitude', 
        log_scale = True
    ),
    dict(
        feature_name = 'BPmag_abs', 
        title = 'Absolute magnitude in the BP band', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Magnitude', 
        log_scale = True
    ), 
    dict(
        feature_name = 'RPmag_abs', 
        title = 'Absolute magnitude in the RP band', 
        xaxis_title = 'Stellar type',                    
        yaxis_title = 'Magnitude', 
        log_scale = True
    ), 

]
plot = plot_subplots(data_to_plot, df, violinplot_feature)

The cameras of the Gaia telescope are black-and-white cameras. To obtain the color of the objects captured by the telescope, filters that only let specific wavelengths pass are used, and the values from each filter are combined. Here is the detail of the G, BP and RP filters:

[Image: the Gaia G, BP and RP passbands]

source

If we want to use the colors as predictors, we need all of the G, BP and RP data. We have to choose between apparent and absolute magnitudes, because they are linearly dependent. Let's check how the filters are related:
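A related linear dependence shows up in the color indices themselves: (BP − G) + (G − RP) = BP − RP, so one index is always redundant given the other two. A tiny check with made-up magnitudes:

```python
# Made-up magnitudes, only to illustrate the identity between color indices.
bp, g, rp = 12.4, 12.0, 11.5

bp_rp = bp - rp   # BP - RP color
bp_g = bp - g     # BP - G color
g_rp = g - rp     # G - RP color

# (BP - G) + (G - RP) equals BP - RP up to floating-point rounding.
print(abs((bp_g + g_rp) - bp_rp) < 1e-12)  # True
```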

In [ ]:
color_features = ['Gmag', 'RPmag', 'BPmag']
colors_couples = list(itertools.combinations(color_features, 2))

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(20, 10))
fig.subplots_adjust(hspace=0.5, wspace=0.5)

for i, d in enumerate(colors_couples):
    
    x, y = d
    axs[0, i].scatter(df[x],
                      df[y], 
                      c=colors,
                      s=0.2, 
                      edgecolor=None, 
                      alpha=0.5
                     )
    axs[0, i].set_xlabel(x)
    axs[0, i].set_ylabel(y)
    axs[0, i].set_title(f'{x} - {y} relative magnitude')
        
for i, d in enumerate(colors_couples):
    x, y = d
    axs[1, i].scatter(df[x + '_abs'],
                      df[y + '_abs'], 
                      c=colors,
                      s=0.2, 
                      edgecolor=None, 
                      alpha=0.5
                     )
    axs[1, i].set_xlabel(x + '_abs')
    axs[1, i].set_ylabel(y + '_abs')
    axs[1, i].set_title(f'{x} - {y} absolute magnitude')


plt.savefig('colors.png', transparent=True)
plt.show()

2. Feature selection and data cleaning

Now we will make a selection of the features we want to use for the classification. I chose to keep the following features:

  • Teff: Estimated effective temperature of the celestial object by Gaia in Kelvins.
  • Dist: Distance to the celestial object: inverse of the parallax, in parsecs.
  • Rad: Object radius estimate in terms of solar radius.
  • Lum-Flame: Estimated object luminosity in terms of solar luminosity.
  • Mass-Flame: Mass estimate in terms of solar mass.
  • Gmag_abs: The absolute magnitude in the G filter
  • RPmag_abs: The absolute magnitude in the RP filter
  • BPmag_abs: The absolute magnitude in the BP filter
  • logg: The surface gravity

I will also keep the column SpType-ELS (the spectral type), as it is our target variable.

In [ ]:
col_to_keep = [
    'SpType-ELS',
    'Teff', 
    'logg',
    'Dist',
    'Rad',
    'Lum-Flame', 
    'Mass-Flame',
    'Gmag_abs', 
    'RPmag_abs',
    'BPmag_abs',   
]
df_work = df.select(col_to_keep)

Let's check the null count for each column:

In [ ]:
(
    df_work
    .group_by('SpType-ELS')
    .agg(pl.all().is_null().sum())
)
Out[ ]:
shape: (7, 10)
[Truncated output: all null counts are zero except for Lum-Flame and Mass-Flame in the B and O types, which each have several thousand nulls, and a couple of isolated nulls for the M type.]

The null counts for the O type are a bit problematic. As we have only $26'016$ stars of type O, we can't use the Age-Flame feature. We are also missing a lot of data for Lum-Flame and Mass-Flame, so we will take a deeper look at these two features. We also have some missing data in the same features for the stars of class M; as we have a lot of data for this stellar type and we will perform downsampling, we'll just drop the missing values for it.

In [ ]:
plt.figure(figsize=(10, 6))
plt_df = df_work.filter(pl.col('SpType-ELS') == 'O')

plt.hist(plt_df['Mass-Flame'], 
         bins=50, 
         alpha=0.9, 
         density=True, 
         color=spectral_type_to_color['O']
        )
plt.axvline(plt_df['Mass-Flame'].median(), 
            color='red', 
            linestyle='--', 
            label='Median'
           )
plt.axvline(plt_df['Mass-Flame'].mean(), 
            color='green', 
            linestyle='-', 
            label='Mean'
           )
plt.axvline(plt_df['Mass-Flame'].quantile(0.25), 
            color='blue', 
            linestyle=':', 
            label='Q1'
           )
plt.axvline(plt_df['Mass-Flame'].quantile(0.75), 
            color='orange', 
            linestyle=':', 
            label='Q3'
           )
plt.title('Mass relative to solar mass for O type stars')
plt.xlabel('Mass (M/M☉)')
plt.legend()
plt.savefig('M_MS_O.png', transparent=True)
In [ ]:
plt.figure(figsize=(10, 6))
plt_df = df_work.filter(pl.col('SpType-ELS') == 'O')

plt.hist(plt_df['Lum-Flame'], 
         bins=50, 
         alpha=0.9, 
         density=True, 
         color=spectral_type_to_color['O']
        )
plt.axvline(plt_df['Lum-Flame'].median(), 
            color='red', 
            linestyle='--', 
            label='Median'
           )
plt.axvline(plt_df['Lum-Flame'].mean(), 
            color='green', 
            linestyle='-', 
            label='Mean'
           )
plt.axvline(plt_df['Lum-Flame'].quantile(0.25), 
            color='blue', 
            linestyle=':', 
            label='Q1'
           )
plt.axvline(plt_df['Lum-Flame'].quantile(0.75), 
            color='orange', 
            linestyle=':', 
            label='Q3'
           )
plt.title('Luminosity relative to solar Luminosity for O type stars')
plt.xlabel('Luminosity (L/L☉)')
plt.legend()
plt.savefig('L_LS_O.png', transparent=True)

Given these plots, I chose to replace the null values with the following approach:

  • Mass of the O stars: replace nulls with the median value
  • Luminosity of the O stars: replace nulls with the mean
In [ ]:
# Use Series aggregations so we get plain floats, not 1x1 dataframes
mass_median = df_work.filter(pl.col('SpType-ELS') == 'O')['Mass-Flame'].median()
lum_mean = df_work.filter(pl.col('SpType-ELS') == 'O')['Lum-Flame'].mean()

df_work = (
    df_work
    .with_columns(
        pl.when((pl.col('Mass-Flame').is_null()) & (pl.col('SpType-ELS') == 'O'))
        .then(pl.col('Mass-Flame').fill_null(mass_median))
        .otherwise(pl.col('Mass-Flame')),
                
        pl.when((pl.col('Lum-Flame').is_null()) & (pl.col('SpType-ELS') == 'O'))
        .then(pl.col('Lum-Flame').fill_null(lum_mean))
        .otherwise(pl.col('Lum-Flame'))
    )
    .drop_nulls()
)
In [ ]:
(
    df_work
    .group_by('SpType-ELS')
    .agg(pl.all().is_null().sum())
)
Out[ ]:
shape: (7, 10)
[Truncated output: every null count is now zero for all spectral types and columns.]

Now we don't have any null values!

In [ ]:
label_encoder = LabelEncoder()
label_encoder.fit(df_work['SpType-ELS'])

df_work = (
    df_work
    .with_columns(
        pl.col('SpType-ELS').map_batches(label_encoder.transform)
    )
).select(pl.all().shuffle(seed=1))
In [ ]:
plot_df = (
    df_work
    .group_by(pl.col('SpType-ELS'))
    .agg(pl.len().alias('Count'))
    .sort('SpType-ELS')
)

plot = sns.barplot(x=plot_df['SpType-ELS'].to_numpy(), 
                   y=plot_df['Count'].to_numpy(), 
                   palette='Spectral_r', 
                   saturation=0.9
                  )
plot.set_title('Count of stars by type')
plot.set_xlabel('Stellar classification')
plot.set_ylabel('Count')
fig = plot.get_figure()
fig.savefig('Counts.png', transparent=True)

Let's check the correlation matrix and the pairplot.

In [ ]:
corr_df = df_work.corr().select(pl.all().round(3))
corr_matrix = corr_df.to_numpy()
corr_matrix[np.triu_indices(corr_matrix.shape[0], 1)] = np.nan 
cmap = plt.get_cmap('RdBu_r')

fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(corr_matrix, cmap=cmap, vmin=-1, vmax=1)
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8)

ax.set_title('Correlation Matrix')
ax.set_xlabel('Features')
ax.set_ylabel('Features')

ax.set_xticks(np.arange(len(corr_matrix)))
ax.set_yticks(np.arange(len(corr_matrix)))
ax.set_xticklabels(corr_df.columns)
ax.set_yticklabels(corr_df.columns)
ax.spines[:].set_visible(False)

plt.setp(ax.get_xticklabels(), 
         rotation=45, 
         ha="right", 
         rotation_mode="anchor")

for i in range(len(corr_matrix)):
    for j in range(len(corr_matrix)):
        if i < j :
            continue
        if abs(corr_matrix[i, j]) >= 0.6:
            color = 'white'
        else:
            color = 'black'
        text = ax.text(j, 
                       i, 
                       corr_matrix[i, j], 
                       ha="center", 
                       va="center", 
                       color=color)
plt.savefig('correlation_matrix.png', transparent=True)
plt.show()

As explained, the magnitudes are correlated, but we need all three features to express the color.

In [ ]:
features = corr_df.columns[1:]
correlations = corr_df['SpType-ELS'].to_numpy()[1:]

cmap = plt.get_cmap('RdBu_r')
norm = mpl_colors.Normalize(vmin=-0.5, vmax=0.5)


fig, ax = plt.subplots(figsize=(16, 6))

ax.bar(features, 
       correlations, 
       color=[cmap(norm(correlation)) for correlation in correlations]
      )

ax.set_title('Correlation between target and features')
ax.set_xlabel('Features')
ax.set_ylabel('Correlation')

plt.savefig('Correlation.png', transparent=True)
plt.show()

3. Models

In [ ]:
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.decomposition import PCA
from sklearn.tree import DecisionTreeClassifier
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

Before we start the analysis, we need to split our dataset into training and test datasets.

In [ ]:
df_model = df_work.with_columns(pl.col('SpType-ELS').map_batches(label_encoder.inverse_transform))

X = df_model.select(pl.all().exclude('SpType-ELS')).to_numpy()
y = df_model['SpType-ELS']

scaled_X = StandardScaler().fit_transform(X)
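The split itself can be done with scikit-learn's train_test_split. A minimal sketch on toy arrays (the data here is invented), using stratify to preserve the class proportions, which matters given the imbalance noted earlier:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy feature matrix (10 samples, 2 features) and string labels, 5 per class.
X_toy = np.arange(20, dtype=float).reshape(10, 2)
y_toy = np.array(["A"] * 5 + ["M"] * 5)

# Hold out 20% for testing; stratify keeps one sample of each class in the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X_toy, y_toy, test_size=0.2, stratify=y_toy, random_state=42
)
print(X_train.shape, X_test.shape)  # (8, 2) (2, 2)
```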
In [ ]:
pca = PCA()
pca.fit(scaled_X)

fig, axs = plt.subplots(1, 2, figsize=(16, 5))

axs[0].bar(range(pca.n_components_), pca.explained_variance_ratio_)
axs[0].set_xlabel('Principal Component number')
axs[0].set_title('Explained Variance Ratio')

axs[1].plot(np.cumsum(pca.explained_variance_ratio_))
axs[1].set_title('Cumulative Explained Variance Ratio')

plt.savefig('PCA.png', transparent=True)
plt.show()

With 3 components the cumulative explained variance ratio is about 0.9, hence we perform PCA with 3 components

In [ ]:
pca = PCA(n_components=3, random_state=42)
pca.fit(scaled_X)
pca_X = pca.transform(scaled_X)

We also need to define a method that finds the correct label mapping for the results: since we are working with unsupervised models, the cluster ids are arbitrary and must be matched to the spectral classes.

In [ ]:
def map_labels_pl(ytrue: np.ndarray, 
                  ypred: np.ndarray, 
                 ) -> tuple[dict, float]:

    d = dict(ytrue=ytrue, ypred=ypred)
    labels = np.unique(ytrue)  # use the passed-in labels, not a global
    perms = list(itertools.permutations(range(len(labels))))
    candidates = [{label: value for label, value in zip(labels, perm)} for perm in perms]

    test_df = pl.DataFrame(d)

    for idx, candidate in enumerate(candidates):
        test_df = test_df.with_columns(
            pl.col('ytrue').replace(candidate).cast(pl.UInt8).alias(f'C_{idx}')
        )
    
    best_match = (
        test_df
        .with_columns(
            (pl.all().exclude(['ytrue', 'ypred']) == pl.col('ypred'))
        )
        .select(pl.all().exclude(['ytrue', 'ypred']))
        .sum()
        .transpose(include_header=True, column_names=['true_count'])
        .sort(by='true_count', descending=True)
        .head(1)
    ).to_dicts()[0]
    
    best_acc = best_match['true_count']/len(ypred)
    best_label_idx = int(best_match['column'].split("_")[-1])
    best_label = candidates[best_label_idx]
    
    return best_label, best_acc
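To make the brute-force mapping concrete, here is a toy version of the same idea in plain Python (the arrays are illustrative, not from the dataset): every assignment of class labels to cluster ids is tried, and the one that agrees with the predictions most often wins.

```python
import itertools

import numpy as np

# Toy ground truth (letters) and cluster ids (integers), where cluster 1
# happens to correspond to 'A' and cluster 0 to 'B'.
ytrue = np.array(['A', 'A', 'B', 'B', 'A'])
ypred = np.array([1, 1, 0, 0, 1])

labels = np.unique(ytrue)
best_map, best_acc = None, -1.0
# Try every permutation of cluster ids over the class labels and keep the
# assignment with the highest agreement.
for perm in itertools.permutations(range(len(labels))):
    mapping = dict(zip(labels.tolist(), perm))
    acc = float(np.mean([mapping[t] == p for t, p in zip(ytrue, ypred)]))
    if acc > best_acc:
        best_map, best_acc = mapping, acc

print(best_map, best_acc)  # {'A': 1, 'B': 0} 1.0
```

With 7 classes this is 7! = 5040 candidates, which is exactly what the polars implementation above evaluates in a vectorised way.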

We also create a function to evaluate our model

In [ ]:
def plot_confusion_matrix(conf_matrix: np.ndarray, 
                          labels: list,
                          cmap: str = 'viridis',
                          title: str = 'Confusion Matrix',
                          figsize: tuple = (10, 10)):
    
    labels_vals = list(sorted(labels))
    x_sums = conf_matrix.sum(axis=0)
    y_sums = conf_matrix.sum(axis=1)
    xticks = [f'{x}\n{y}' for x, y in zip(labels_vals, x_sums)]
    yticks = [f'{x}\n{y}' for x, y in zip(labels_vals, y_sums)]
    plt.figure(figsize=figsize)
    plt.imshow(conf_matrix, cmap=cmap)
    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.xticks(range(conf_matrix.shape[1]), labels=xticks)
    plt.yticks(range(conf_matrix.shape[0]), labels=yticks)
    for i in range(conf_matrix.shape[0]):
        for j in range(conf_matrix.shape[1]):
            if conf_matrix[i, j] > 10000:
                color='black'
            else: 
                color='white'
            plt.text(j, 
                     i, 
                     conf_matrix[i, j], 
                     ha="center", 
                     va="center", 
                     color=color)

    plt.colorbar(fraction=0.046, pad=0.04)
    return plt

def evaluate_model(ytrue: list[str], 
                   ypred: list[int], 
                   labels: dict = None):
    
    if labels is None: 
        ypred_cat = ypred
    else:
        mapping_labels={v:k for k, v in labels.items()}
        ypred_cat = list(map(mapping_labels.get, ypred))
    
    report = classification_report(y_true=ytrue, y_pred=ypred_cat)
    conf_matrix = confusion_matrix(y_true=ytrue, y_pred=ypred_cat)
    
    print(report)
    return plot_confusion_matrix(conf_matrix, list(set(ypred_cat)))

3.1 Unsupervised: KMeans clustering

In [ ]:
kmeans_model = KMeans(n_clusters=7, n_init='auto', )
kmeans_model.fit(pca_X)
yhat_kmeans = kmeans_model.predict(pca_X)
In [ ]:
kmeans_labels, kmeans_acc = map_labels_pl(y, yhat_kmeans)
print(f'The accuracy for the model is {kmeans_acc}')
The accuracy for the model is 0.4955160713874345
In [ ]:
fig = evaluate_model(y, yhat_kmeans, kmeans_labels)
fig.savefig('kmeans.png', transparent=True)
              precision    recall  f1-score   support

           A       0.45      0.86      0.59    100000
           B       0.80      0.68      0.73     81711
           F       0.00      0.00      0.00    100000
           G       0.43      0.90      0.58    100000
           K       0.23      0.09      0.13    100000
           M       0.74      0.44      0.55     99999
           O       0.85      0.59      0.69     26016

    accuracy                           0.50    607726
   macro avg       0.50      0.51      0.47    607726
weighted avg       0.45      0.50      0.43    607726

The results for KMeans are not very good with the default parameters. The model fails to separate the stars of the F class and mainly classifies them as G. We can try another initialization strategy.

In [ ]:
kmeans_model = KMeans(n_clusters=7, n_init='auto', init='random')
kmeans_model.fit(pca_X)
yhat_kmeans = kmeans_model.predict(pca_X)
In [ ]:
kmeans_labels, kmeans_acc = map_labels_pl(y, yhat_kmeans)
print(f'The accuracy for the model is {kmeans_acc}')
The accuracy for the model is 0.55177497753922
In [ ]:
evaluate_model(y, yhat_kmeans, kmeans_labels)
              precision    recall  f1-score   support

           A       0.52      0.91      0.66    100000
           B       0.79      0.77      0.78     81711
           F       0.00      0.00      0.00    100000
           G       0.47      0.88      0.61    100000
           K       0.47      0.34      0.40    100000
           M       0.77      0.44      0.56     99999
           O       0.84      0.61      0.70     26016

    accuracy                           0.55    607726
   macro avg       0.55      0.56      0.53    607726
weighted avg       0.51      0.55      0.50    607726


The results are similarly poor; KMeans seems to be a bad choice for this dataset.
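One way to quantify how poorly the clusters separate (not used above) is the silhouette score. Here is a sketch on synthetic data, since computing the score on all ~600k points is expensive; on the real data, `silhouette_score(pca_X, kmeans_model.labels_, sample_size=10_000)` would be the analogue:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Synthetic stand-in for pca_X: heavily overlapping blobs mimic poorly
# separated spectral classes.
X_demo, _ = make_blobs(n_samples=1000, centers=7, cluster_std=5.0,
                       random_state=42)

km = KMeans(n_clusters=7, n_init=10, random_state=42).fit(X_demo)
# Values near 1 mean well separated clusters; values near 0 mean the
# clusters overlap, which is what we observe for the star classes.
score = silhouette_score(X_demo, km.labels_)
print(f'Silhouette score: {score:.3f}')
```

A low silhouette score would confirm that the class boundaries in PCA space are simply not cluster-shaped, independently of the label mapping.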

3.2 Unsupervised: KMeans clustering with oversampling

Here we will use SMOTE to oversample the dataset and run KMeans again.

In [ ]:
smote = SMOTE()

X_smote, y_smote = smote.fit_resample(pca_X, y)
In [ ]:
kmeans_model = KMeans(n_clusters=7, n_init='auto', )
kmeans_model.fit(X_smote)
yhat_kmeans = kmeans_model.predict(X_smote)
In [ ]:
kmeans_labels, kmeans_acc = map_labels_pl(y_smote, yhat_kmeans)
print(f'The accuracy for the model is {kmeans_acc}')
The accuracy for the model is 0.42214285714285715
In [ ]:
fig = evaluate_model(y_smote, yhat_kmeans, kmeans_labels)
fig.savefig('test.png', transparent=True)
              precision    recall  f1-score   support

           A       0.34      0.74      0.47    100000
           B       0.35      0.29      0.31    100000
           F       0.40      0.95      0.56    100000
           G       0.00      0.00      0.00    100000
           K       0.18      0.07      0.10    100000
           M       0.70      0.44      0.54    100000
           O       0.97      0.48      0.64    100000

    accuracy                           0.42    700000
   macro avg       0.42      0.42      0.37    700000
weighted avg       0.42      0.42      0.37    700000

3.3 Supervised: Decision Tree

Now let's try a supervised model. For this first model I chose a Decision Tree Classifier. We start with the original dataset (without oversampling).

Our first task is to split the data into train/test datasets.

In [ ]:
X_train, X_test, y_train, y_test = train_test_split(pca_X, y, test_size = 0.3)
In [ ]:
tree = DecisionTreeClassifier()

tree.fit(X=X_train, y=y_train)
yhat_tree = tree.predict(X_test)
In [ ]:
fig = evaluate_model(y_test, yhat_tree)
fig.savefig('dt_base.png', transparent=True)
              precision    recall  f1-score   support

           A       0.89      0.89      0.89     29919
           B       0.95      0.95      0.95     24716
           F       0.63      0.63      0.63     30264
           G       0.56      0.56      0.56     29751
           K       0.70      0.70      0.70     30086
           M       0.85      0.85      0.85     29833
           O       0.91      0.92      0.91      7749

    accuracy                           0.76    182318
   macro avg       0.78      0.78      0.78    182318
weighted avg       0.76      0.76      0.76    182318

This model is already a lot better than the previous ones. Let's run hyperparameter tuning on it.

In [ ]:
parameters = {'criterion':['gini'],
              'max_depth':np.arange(5,16).tolist()[0::3],
              'min_samples_split':np.arange(2,9).tolist()[0::3],
              'max_leaf_nodes':np.arange(12,42).tolist()[0::4]}

grid = GridSearchCV(DecisionTreeClassifier(), 
                   parameters, 
                   cv=4, 
                   n_jobs=-1, 
                   )

grid.fit(X_train, y_train)
Out[ ]:
GridSearchCV(cv=4, estimator=DecisionTreeClassifier(), n_jobs=-1,
             param_grid={'criterion': ['gini'], 'max_depth': [5, 8, 11, 14],
                         'max_leaf_nodes': [12, 16, 20, 24, 28, 32, 36, 40],
                         'min_samples_split': [2, 5, 8]})
In [ ]:
res = dict(
    score=grid.cv_results_['mean_test_score'], 
    time=grid.cv_results_['mean_fit_time'], 
    max_depth=grid.cv_results_['param_max_depth'].tolist(), 
    max_leaf_nodes=grid.cv_results_['param_max_leaf_nodes'].tolist(),
    min_samples_split=grid.cv_results_['param_min_samples_split'].tolist(),
)

results = pl.DataFrame(res)

results.sort('score', descending=True).head(10)
Out[ ]:
shape: (10, 5)
 score      time       max_depth   max_leaf_nodes   min_samples_split
 f64        f64        i64         i64              i64
 0.804811   2.71216    8           40               2
 0.804811   2.730097   8           40               5
 0.804811   2.492039   8           40               8
 0.804811   2.492971   11          40               2
 0.804811   2.790861   11          40               5
 0.804811   2.777145   11          40               8
 0.804811   2.507773   14          40               2
 0.804811   2.552844   14          40               5
 0.804811   2.439458   14          40               8
 0.802909   2.475609   8           36               2

Let's plot these results

In [ ]:
plot_df = (
    results
    .group_by(['max_depth', 'min_samples_split'])
    .all()
)
In [ ]:
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(10, 10))

for row in plot_df.iter_rows(named=True):
    ax1.scatter(row['max_leaf_nodes'], 
                row['score'],
                s=8,
                label=f"max depth: {row['max_depth']}, min samples split: {row['min_samples_split']}")

ax1.set_xlabel('max leaf nodes')
ax1.set_ylabel("Score")
ax1.set_title("Score vs. max leaf nodes")

for row in plot_df.iter_rows(named=True):
    ax2.plot(row['max_leaf_nodes'], 
             row['time'],
             label=f"max depth: {row['max_depth']}, min samples split: {row['min_samples_split']}")
    
ax2.set_xlabel('max leaf nodes')
ax2.set_ylabel("Time")
ax2.set_title("Time vs. max leaf nodes")

lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
fig.legend(lines, labels, loc='upper left', bbox_to_anchor=(1.0, 1.0))

plt.tight_layout()
plt.savefig('hyper.png', transparent=True)
plt.show()

Many parameter combinations reach a similar score, but their fit times differ. To inspect this, we keep only the best scores and sort them by fit time.

In [ ]:
best = (
    results
    .filter(pl.col('score') == pl.col('score').max())
    .sort('time')
)
best
Out[ ]:
shape: (9, 5)
 score      time       max_depth   max_leaf_nodes   min_samples_split
 f64        f64        i64         i64              i64
 0.804811   2.439458   14          40               8
 0.804811   2.492039   8           40               8
 0.804811   2.492971   11          40               2
 0.804811   2.507773   14          40               2
 0.804811   2.552844   14          40               5
 0.804811   2.71216    8           40               2
 0.804811   2.730097   8           40               5
 0.804811   2.777145   11          40               8
 0.804811   2.790861   11          40               5

The model that is best in both score and fit time uses max_depth = 14, max_leaf_nodes = 40 and min_samples_split = 8. Let's fit a model with these parameters.

In [ ]:
params = best.head(1).drop(['score', 'time']).to_dicts()[0]
tree = DecisionTreeClassifier(**params)

tree.fit(X=X_train, y=y_train)
yhat_tree = tree.predict(X_test)
In [ ]:
evaluate_model(y_test, yhat_tree)
              precision    recall  f1-score   support

           A       0.89      0.93      0.91     29919
           B       0.96      0.89      0.93     24716
           F       0.69      0.71      0.70     30264
           G       0.66      0.62      0.64     29751
           K       0.76      0.77      0.76     30086
           M       0.90      0.87      0.88     29833
           O       0.78      0.96      0.86      7749

    accuracy                           0.80    182318
   macro avg       0.80      0.82      0.81    182318
weighted avg       0.80      0.80      0.80    182318


The results are slightly better than those of the original model.
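A single train/test split can be optimistic or pessimistic by chance. A sketch of double-checking the tuned parameters with cross-validation, on synthetic stand-in data (on the real data, substitute `pca_X` and `y` for `X_demo` and `y_demo`):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the PCA-reduced features.
X_demo, y_demo = make_classification(n_samples=2000, n_features=3,
                                     n_informative=3, n_redundant=0,
                                     n_classes=4, random_state=42)

# Score the default and the tuned tree across 5 folds instead of trusting
# a single split.
default_scores = cross_val_score(DecisionTreeClassifier(random_state=42),
                                 X_demo, y_demo, cv=5)
tuned_scores = cross_val_score(
    DecisionTreeClassifier(max_depth=14, max_leaf_nodes=40,
                           min_samples_split=8, random_state=42),
    X_demo, y_demo, cv=5)

print(f'default: {default_scores.mean():.3f} +/- {default_scores.std():.3f}')
print(f'tuned:   {tuned_scores.mean():.3f} +/- {tuned_scores.std():.3f}')
```

If the fold-to-fold standard deviation is larger than the gap between the two models, the improvement should be treated as tentative.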

3.4 Supervised: Decision Tree with oversampling

In [ ]:
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(X_smote, y_smote, test_size = 0.3)
In [ ]:
tree = DecisionTreeClassifier()

tree.fit(X=X_train_s, y=y_train_s)
yhat_tree = tree.predict(X_test_s)
In [ ]:
evaluate_model(y_test_s, yhat_tree)
              precision    recall  f1-score   support

           A       0.89      0.89      0.89     29930
           B       0.95      0.95      0.95     29702
           F       0.62      0.62      0.62     30274
           G       0.55      0.56      0.56     30033
           K       0.70      0.70      0.70     30105
           M       0.85      0.86      0.85     30105
           O       0.97      0.98      0.97     29851

    accuracy                           0.79    210000
   macro avg       0.79      0.79      0.79    210000
weighted avg       0.79      0.79      0.79    210000


In this case oversampling yields a slightly worse result.

3.5 Supervised: Other methods

Random forest¶

In [ ]:
rfc = RandomForestClassifier()
rfc.fit(X_train, y_train)
yhat_rf = rfc.predict(X_test)
In [ ]:
fig = evaluate_model(y_test, yhat_rf)
fig.savefig('rf.png', transparent=True)
              precision    recall  f1-score   support

           A       0.91      0.94      0.92     29919
           B       0.97      0.96      0.96     24716
           F       0.71      0.70      0.70     30264
           G       0.65      0.65      0.65     29751
           K       0.77      0.79      0.78     30086
           M       0.91      0.87      0.89     29833
           O       0.93      0.95      0.94      7749

    accuracy                           0.82    182318
   macro avg       0.84      0.84      0.84    182318
weighted avg       0.82      0.82      0.82    182318

Random forest without hyperparameter tuning returns decent results, close to (but better than) the tuned decision tree.
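A side benefit of random forests is that they expose impurity-based feature importances for free, which here would tell us how much each principal component drives the classification. A sketch on synthetic stand-in data (on the real data, read `rfc.feature_importances_` directly from the fitted model above):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in for the three PCA components.
X_demo, y_demo = make_classification(n_samples=1000, n_features=3,
                                     n_informative=3, n_redundant=0,
                                     random_state=42)

rfc_demo = RandomForestClassifier(n_estimators=100, random_state=42)
rfc_demo.fit(X_demo, y_demo)

# Impurity-based importances sum to 1; larger values mean the component
# is used more often for informative splits.
for i, imp in enumerate(rfc_demo.feature_importances_):
    print(f'PC{i + 1}: {imp:.3f}')
```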

Logistic regression¶

In [ ]:
lr = LogisticRegression(max_iter=10000)
lr.fit(X_train, y_train)
yhat_lr = lr.predict(X_test)
In [ ]:
fig = evaluate_model(y_test, yhat_lr)
fig.savefig('lr.png', transparent=True)
              precision    recall  f1-score   support

           A       0.89      0.94      0.91     29919
           B       0.88      0.92      0.90     24716
           F       0.73      0.59      0.65     30264
           G       0.57      0.74      0.65     29751
           K       0.72      0.66      0.69     30086
           M       0.88      0.82      0.85     29833
           O       0.78      0.68      0.73      7749

    accuracy                           0.77    182318
   macro avg       0.78      0.76      0.77    182318
weighted avg       0.78      0.77      0.77    182318

Logistic regression performs well too, but worse than our first supervised model.

4. Conclusion

This study initially focused on exploratory data analysis. The key observations from that part were the distribution of each feature and the high dispersion of the data. The first modelling step was then to perform a PCA reduction and keep only the first 3 principal components. The modelling sections highlighted that the unsupervised approaches (at least those tested) do not fit this dataset well; the high variability of the data may be the cause. Supervised approaches were more successful.

Future work may include:

  • Fitting other unsupervised ML approaches;
  • Preprocessing the features (feature engineering);
  • Hyperparameter fine-tuning for the Random Forest and Logistic Regression classifiers;
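As a starting point for the Random Forest tuning mentioned above, a randomized search is a cheaper alternative to a full grid on a dataset of this size. A sketch on synthetic stand-in data (substitute `pca_X` and `y` on the real dataset; the parameter ranges are illustrative):

```python
from scipy.stats import randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

# Synthetic stand-in for the PCA-reduced features.
X_demo, y_demo = make_classification(n_samples=1000, n_features=3,
                                     n_informative=3, n_redundant=0,
                                     random_state=42)

# Randomized search samples a fixed number of candidates (n_iter) from
# the distributions, which scales better than an exhaustive grid.
param_dist = {
    'n_estimators': randint(50, 300),
    'max_depth': randint(5, 20),
    'min_samples_split': randint(2, 10),
}
search = RandomizedSearchCV(RandomForestClassifier(random_state=42),
                            param_dist, n_iter=10, cv=3, n_jobs=-1,
                            random_state=42)
search.fit(X_demo, y_demo)
print(search.best_params_)
```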

As an amateur astronomer, I will continue this project and perhaps reuse the dataset for a deep learning approach.

github repo

In [ ]: